{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert a TensorFlow model to ONNX\n",
    "TensorFlow and ONNX each define their own graph format to represent a model. You can use [tensorflow-onnx](https://github.com/onnx/tensorflow-onnx) to export a TensorFlow model to ONNX.\n",
    "\n",
    "This guide has two parts: part 1 covers basic conversion and part 2 covers advanced topics. The following topics are covered in order:\n",
    "1. Procedure to convert a TensorFlow model\n",
    "   - get a TensorFlow model\n",
    "   - convert it to ONNX\n",
    "   - validate the result\n",
    "2. Key concepts\n",
    "   - opset\n",
    "   - data format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Get a TensorFlow model\n",
    "TensorFlow uses several file formats to represent a model, such as checkpoint files, a graph with weights (called a `frozen graph` below), and saved_model. It has APIs to generate each of these, and the script [tensorflow_to_onnx_example.py](./assets/tensorflow_to_onnx_example.py) contains code snippets for them.\n",
    "\n",
    "`tensorflow-onnx` accepts all three formats, but **the \"saved_model\" format is preferred** because it doesn't require you to specify the graph's input and output names.\n",
    "This section covers saved_model; the other two formats are covered in [part 2](./TensorflowToOnnx-2.ipynb). You can find more detail in the `tensorflow-onnx` [README](https://github.com/onnx/tensorflow-onnx/blob/master/README.md)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "please wait for a while, because the script will train MNIST from scratch\n",
      "Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz\n",
      "step 0, training accuracy 0.18\n",
      "step 1000, training accuracy 0.98\n",
      "step 2000, training accuracy 0.94\n",
      "step 3000, training accuracy 1\n",
      "step 4000, training accuracy 1\n",
      "test accuracy 0.976\n",
      "save tensorflow in format \"saved_model\"\n"
     ]
    }
    ],
    "source": [
     "import os\n",
     "import shutil\n",
     "\n",
     "import tensorflow as tf\n",
     "from tensorflow.saved_model import simple_save\n",
     "\n",
     "from assets.tensorflow_to_onnx_example import create_and_train_mnist\n",
     "\n",
     "\n",
     "def save_model_to_saved_model(sess, input_tensor, output_tensor):\n",
     "    \"\"\"Export the session's graph and weights in the saved_model format.\"\"\"\n",
     "    save_path = r\"./output/saved_model\"\n",
     "    # The export directory must not already exist, so remove any previous export.\n",
     "    if os.path.exists(save_path):\n",
     "        shutil.rmtree(save_path)\n",
     "    simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})\n",
     "\n",
     "\n",
     "# This cell uses TensorFlow 1.x APIs (tf.reset_default_graph, simple_save).\n",
     "print(\"please wait for a while, because the script will train MNIST from scratch\")\n",
     "tf.reset_default_graph()\n",
     "sess_tf, saver, input_tensor, output_tensor = create_and_train_mnist()\n",
     "print(\"save tensorflow in format \\\"saved_model\\\"\")\n",
     "save_model_to_saved_model(sess_tf, input_tensor, output_tensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Convert to ONNX\n",
    "`tensorflow-onnx` provides several entry points for converting a TensorFlow model, one for each input format. This section covers \"saved_model\" only; \"frozen graph\" and \"checkpoint\" are covered in [part 2](./TensorflowToOnnx-2.ipynb).\n",
    "\n",
    "`tensorflow-onnx` also exposes Python APIs, so you can call the converter directly from a script instead of the command line. The details are covered in [part 2](./TensorflowToOnnx-2.ipynb) as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2019-06-17 07:22:03,871 - INFO - Using tensorflow=1.12.0, onnx=1.5.0, tf2onnx=1.5.1/0c735a\n",
      "2019-06-17 07:22:03,871 - INFO - Using opset <onnx, 7>\n",
      "2019-06-17 07:22:03,989 - INFO - \n",
      "2019-06-17 07:22:04,012 - INFO - Optimizing ONNX model\n",
      "2019-06-17 07:22:04,029 - INFO - After optimization: Add -2 (4->2), Identity -3 (3->0), Transpose -8 (9->1)\n",
      "2019-06-17 07:22:04,031 - INFO - \n",
      "2019-06-17 07:22:04,032 - INFO - Successfully converted TensorFlow model ./output/saved_model to ONNX\n",
      "2019-06-17 07:22:04,044 - INFO - ONNX model is saved at ./output/mnist1.onnx\n"
     ]
    }
    ],
    "source": [
     "# Generate mnist1.onnx from the saved_model exported in Step 1\n",
     "!python -m tf2onnx.convert \\\n",
     "        --saved-model ./output/saved_model \\\n",
     "        --output ./output/mnist1.onnx \\\n",
     "        --opset 7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Validate\n",
    "Several frameworks can run models in ONNX format. Here [ONNX Runtime](https://github.com/microsoft/onnxruntime), open-sourced by Microsoft, is used to check that the generated ONNX graph behaves as expected.\n",
    "The input \"image.npz\" is an image of a handwritten \"7\", so the model should classify it as \"7\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "the expected result is \"7\"\n",
       "the digit is classified as \"7\" in ONNXRuntime\n"
      ]
     }
    ],
    "source": [
     "import numpy as np\n",
     "import onnxruntime as ort\n",
     "\n",
     "# Reshape the image into a batch of one flattened 784-pixel vector.\n",
     "img = np.load(\"./assets/image.npz\").reshape([1, 784])\n",
     "sess_ort = ort.InferenceSession(\"./output/mnist1.onnx\")\n",
     "# Reuse the TensorFlow tensor names from Step 1 as the ONNX output and input names.\n",
     "res = sess_ort.run(output_names=[output_tensor.name], input_feed={input_tensor.name: img})\n",
     "print(\"the expected result is \\\"7\\\"\")\n",
     "# The index of the largest value in the output is the predicted digit.\n",
     "print(\"the digit is classified as \\\"%s\\\" in ONNXRuntime\" % np.argmax(res))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Key concepts\n",
    "The command line above should work for most TensorFlow models that are available as a saved_model. In some cases you might encounter issues that require extra options.\n",
    "\n",
    "The most important concept is the \"**opset** version\". ONNX is an evolving standard: new operations are added and existing ones are enhanced over time, so different opset versions contain different operations, and the same operation may behave differently between versions. At the time of writing, `tensorflow-onnx` defaults to opset 7 while ONNX supports up to opset 10, so if a conversion fails, try a different version with the command-line option `--opset` to see whether it works.\n",
    "\n",
    "Continue with [part 2](./TensorflowToOnnx-2.ipynb), which explains advanced topics."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
